{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple automatic differentiation illustration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Union, List\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "np.set_printoptions(precision=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = 3\n",
    "a.__add__(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2 3 1 0]\n",
      "Addition using '__add__': [6 7 5 4]\n",
      "Addition using '+': [6 7 5 4]\n"
     ]
    }
   ],
   "source": [
    "a = np.array([2,3,1,0])\n",
    "\n",
    "print(a)\n",
    "print(\"Addition using '__add__':\", a.__add__(4))\n",
    "print(\"Addition using '+':\", a + 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "Numberable = Union[float, int]\n",
    "\n",
    "def ensure_number(num: Numberable):\n",
    "    if isinstance(num, NumberWithGrad):\n",
    "        return num\n",
    "    else:\n",
    "        return NumberWithGrad(num)        \n",
    "\n",
    "class NumberWithGrad(object):\n",
    "    \n",
    "    def __init__(self, \n",
    "                 num: Numberable,\n",
    "                 depends_on: List[Numberable] = None,\n",
    "                 creation_op: str = ''):\n",
    "        self.num = num\n",
    "        self.grad = None\n",
    "        self.depends_on = depends_on or []\n",
    "        self.creation_op = creation_op\n",
    "\n",
    "    def __add__(self, \n",
    "                other: Numberable):\n",
    "        return NumberWithGrad(self.num + ensure_number(other).num,\n",
    "                              depends_on = [self, ensure_number(other)],\n",
    "                              creation_op = 'add')\n",
    "    \n",
    "    def __mul__(self,\n",
    "                other: Numberable = None):\n",
    "\n",
    "        return NumberWithGrad(self.num * ensure_number(other).num,\n",
    "                              depends_on = [self, ensure_number(other)],\n",
    "                              creation_op = 'mul')\n",
    "    \n",
    "    def backward(self, backward_grad: Numberable = None):\n",
    "        if backward_grad is None: # first time calling backward\n",
    "            self.grad = 1\n",
    "        else: \n",
    "            # These lines allow gradients to accumulate.\n",
    "            # If the gradient doesn't exist yet, simply set it equal\n",
    "            # to backward_grad\n",
    "            if self.grad is None:\n",
    "                self.grad = backward_grad\n",
    "            # Otherwise, simply add backward_grad to the existing gradient\n",
    "            else:\n",
    "                self.grad += backward_grad\n",
    "        \n",
    "        if self.creation_op == \"add\":\n",
    "            # Simply send backward self.grad, since increasing either of these \n",
    "            # elements will increase the output by that same amount\n",
    "            self.depends_on[0].backward(self.grad)\n",
    "            self.depends_on[1].backward(self.grad)    \n",
    "\n",
    "        if self.creation_op == \"mul\":\n",
    "\n",
    "            # Calculate the derivative with respect to the first element\n",
    "            new = self.depends_on[1] * self.grad\n",
    "            # Send backward the derivative with respect to that element\n",
    "            self.depends_on[0].backward(new.num)\n",
    "\n",
    "            # Calculate the derivative with respect to the second element\n",
    "            new = self.depends_on[0] * self.grad\n",
    "            # Send backward the derivative with respect to that element\n",
    "            self.depends_on[1].backward(new.num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "# c = 4*a + 3, so dc/da = 4 and dc/db = 1.\n",
    "a = NumberWithGrad(3)\n",
    "b = a * 4\n",
    "c = b + 3\n",
    "c.backward()\n",
    "\n",
    "for node in (a, b):\n",
    "    print(node.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# New leaf node: its `grad` starts out as None again.\n",
    "a = NumberWithGrad(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# e = (4*a + 3) * (a + 2): `a` feeds `e` along two paths (through c and\n",
    "# through d), so its gradient is the sum of both contributions.\n",
    "b = a * 4\n",
    "c = b + 3\n",
    "d = a + 2\n",
    "e = c * d\n",
    "e.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "35"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# de/da = 4*(a + 2) + (4*a + 3) = 20 + 15 = 35 at a = 3\n",
    "a.grad"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
